import copy
import os
import time
from dataclasses import dataclass

import mindspore
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
import mindspore.nn as nn
import numpy as np
from mindspore import Tensor, context, mutable
from mindspore import numpy as msnp
from mindspore import ops

#from trlx.utils import Clock
from mindspore.dataset import GeneratorDataset, MindDataset
from mindspore.ops import functional as F
from mindspore.ops import operations as P

from mindformers import AutoTokenizer
from ppo_models import CausalLMHydraWithValueHead, PPO_model, PPOConfig
from reward_model import CriticModel, RewardModel
from utils import IsFirstStage, IsLastStage, set_pipeline_parallel_context


@dataclass
class PPORLElement:
    query_tensor: Tensor
    response_tensor: Tensor
    logprobs: Tensor
    values: Tensor
    rewards: Tensor
    advantages: Tensor
    returns: Tensor
    pretrain_ids: Tensor


def get_first_diverge_indices(preferred_comp_ids,  # shape = batch_size * seq_length
                              disfavored_comp_ids  # shape = batch_size * seq_length
                              ):
    is_equal = Tensor(preferred_comp_ids == disfavored_comp_ids).astype('float32')
    print("is_equal is: ", is_equal)
    first_diverge_indices = is_equal.sum(axis=1, dtype=mindspore.int32)
    return first_diverge_indices


class RewardFn(nn.Cell):
    def __init__(self, model_config):
        super(RewardFn, self).__init__()

        self.ckpt_path = model_config.checkpoint_name_or_path
        print("RewardFn.ckpt_path: ", self.ckpt_path)
        model_config.checkpoint_name_or_path = ""

        self.pad_token = model_config.pad_token_id
        self.reward_model = RewardModel(model_config)
        self.not_equal = P.NotEqual()

        if self.ckpt_path:
            param_dict = mindspore.load_checkpoint(self.ckpt_path)
            print("=====begin to load reward model ckpt from: ", self.ckpt_path, flush=True)
            param_not_load, ckpt_not_load = mindspore.load_param_into_net(self.reward_model, param_dict)
            print("parameter not loaded: ", param_not_load, flush=True)
            print("ckpt not loaded: ", ckpt_not_load, flush=True)

    def get_scores(self, samples):
        attn_masks = self.not_equal(samples, self.pad_token).astype(mstype.float32)
        end_indices = (attn_masks.sum(axis=1) - 1).to(mstype.int32)
        bs_scores = self.reward_model.infer(samples, end_indices)
        return bs_scores, end_indices

    def construct(self, samples, original_samples):
        original_scores, _ = self.get_scores(original_samples)
        scores, _ = self.get_scores(samples)
        norms_scores = scores - original_scores
        # return scores, original_scores, norms_scores
        return norms_scores


class AcceleratePPOTrainer:
    # reward_fn: Callable[[List[str], List[str], List[str]], List[float]]
    # tokenizer: AutoTokenizer
    def __init__(self,
                 ppo_config=None,
                 sft_model_config=None,
                 ref_model_config=None,
                 critic_model_config=None,
                 rm_model_config=None,
                 opt=None):
        self.mind_dataset_dir = opt.mind_dataset_dir
        columns_to_project = ["prompt_ids", "original_sample_ids", "pretrain_ids"]
        mindspore.dataset.config.set_seed(2023)
        dataset = MindDataset(self.mind_dataset_dir).project(columns=columns_to_project)
        self.prompt_dataloader = dataset.take(ppo_config.num_rollouts)  # ?
        self.prompt_dataloader = self.prompt_dataloader.batch(batch_size=ppo_config.chunk_size
                                                              * sft_model_config.parallel_config.data_parallel)
        self.prompt_iterator = self.prompt_dataloader.create_tuple_iterator()
        self.ppo_config = ppo_config
        self.sft_model_config = sft_model_config
        self.rm_model_config = rm_model_config
        self.opt = opt
        current_path = os.getenv("RLHF_ROOT_DIR")
        if current_path is None:
            raise ValueError(f"Please run `source env.sh` before running the program.")
        self.tokenizer = AutoTokenizer.from_pretrained(current_path + "/gpt2")
        print("self.tokenizer.pad_token_id", self.tokenizer.pad_token_id)
        print("self.tokenizer.eos_token_id", self.tokenizer.eos_token_id)

        policy_model = CausalLMHydraWithValueHead(sft_model_config, self.ppo_config)
        critic_model = CriticModel(critic_model_config)
        self.ppo_model = PPO_model(ppo_config, policy_model, critic_model, self.opt)

        self.ref_model = CausalLMHydraWithValueHead(ref_model_config, self.ppo_config)
        self.ref_model.model.set_train(False)

        self.ref_mean = 0
        self.ref_std = 0
        self.cliprange_reward = 10.0
        self.store = []

        self.reward_fn = RewardFn(rm_model_config)
        self.reward_fn.set_train(False)
        self.reward_fn.reward_model.set_train(False)
        self.reward_fn.reward_model.model.set_train(False)

        self.log_softmax = P.LogSoftmax(axis=-1)
        self.gather = P.GatherD()
        self.unsqueeze = P.ExpandDims()
        self.squeeze = P.Squeeze(axis=-1)
        self.depend = P.Depend()

    def push_to_store(self, data):
        self.store = data

    def generate(self, input_ids, attn_masks=None):
        input_ids_list = input_ids.asnumpy().tolist()
        prompt_len = (np.array(input_ids_list) != self.ppo_config.pad_token_id).astype(int).sum(1)
        left_padding_prompt = np.ones((len(input_ids_list),
                                       self.ppo_config.max_prompt_length)) * self.ppo_config.pad_token_id
        resposne_array = np.ones((len(input_ids_list), self.ppo_config.max_decode_length)) * \
            self.ppo_config.pad_token_id
        samples = np.ones((len(input_ids_list), self.ppo_config.seq_length)) * self.ppo_config.pad_token_id

        generate_begin_time = time.time()
        outputs = self.ppo_model.generate(input_ids_list)
        print("Generating elapsed time: ", time.time() - generate_begin_time, flush=True)

        for i in range(len(input_ids_list)):
            x = outputs[i][prompt_len[i]: prompt_len[i] + self.ppo_config.max_decode_length]
            resposne_array[i, :len(x)] = x
            p = outputs[i]
            samples[i, :len(p)] = p
            left_padding_prompt[i, self.ppo_config.max_prompt_length -
                                prompt_len[i]:] = input_ids_list[i][:prompt_len[i]]
        return Tensor(
            samples, mstype.int32), Tensor(
            resposne_array, mstype.int32), Tensor(
            left_padding_prompt, mstype.int32)

    def partition(self, prompt_tensors, samples):
        n_samples: int = samples.shape[0]
        response_tensors = []
        for ix in range(n_samples):
            # get the start_idx of the response in `prompt_tensors`,
            # where `prompt_tensors` is the concatenated prompt and response
            start = np.max(np.nonzero(np.not_equal(prompt_tensors[ix], self.ppo_config.pad_token_id))) + 1
            response_tensors.append(samples[ix, start: int(start + self.ppo_config.max_decode_length)])
        return response_tensors

    def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
        self.ppo_model.policy_model.model.set_train(False)
        self.ppo_model.critic_model.model.set_train(False)
        self.ref_model.model.set_train(False)
        self.reward_fn.reward_model.set_train(False)
        ppo_rl_elements = []
        while len(ppo_rl_elements) < num_rollouts:
            rollout_total = time.time()
            try:
                batch = next(self.prompt_iterator)
            except StopIteration:
                mindspore.dataset.config.set_seed(2023)
                self.prompt_iterator = self.prompt_dataloader.create_tuple_iterator()
                batch = next(self.prompt_iterator)

            # batch[0]: prompt, right padding to max_prompt_length=1024
            batch_0 = batch[0][:, :512]
            batch_1 = batch[1][:, :1024]
            batch_2 = batch[2][:, :1024]
            prompt_tensors = Tensor(batch_0, mstype.int32)
            pretrain_ids = Tensor(batch_2, mstype.int32)

            self.ppo_model.policy_model.model.add_flags_recursive(use_past=self.opt.use_past)
            # ========================= Generate ======================
            generate_start = time.time()
            samples, resposne_array, left_padding_prompt = self.generate(prompt_tensors)
            generate_end = time.time()
            # =========================================================
            samples = samples.asnumpy()
            resposne_array = resposne_array.asnumpy()
            left_padding_prompt = left_padding_prompt.asnumpy()
            self.ppo_model.policy_model.model.add_flags_recursive(use_past=False)
            print("================== Finish Generating ===============", flush=True)
            # print("prompt: ", flush=True)
            # print("===== 1 \n", self.tokenizer.decode(prompt_tensors[0].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 2 \n", self.tokenizer.decode(prompt_tensors[1].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 3 \n", self.tokenizer.decode(prompt_tensors[2].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 4 \n", self.tokenizer.decode(prompt_tensors[3].asnumpy(), skip_special_tokens=True), flush=True)

            '''print("prompt+generated response: ", flush=True)
            print("===== 1 \n", self.tokenizer.decode(samples[0], skip_special_tokens=True), flush=True)
            print("===== 2 \n", self.tokenizer.decode(samples[1], skip_special_tokens=True), flush=True)
            print("===== 3 \n", self.tokenizer.decode(samples[2], skip_special_tokens=True), flush=True)
            print("===== 4 \n", self.tokenizer.decode(samples[3], skip_special_tokens=True), flush=True)'''

            # print("original samples: ", flush=True)
            # print("===== 1 \n", self.tokenizer.decode(batch[1][0].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 2 \n", self.tokenizer.decode(batch[1][1].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 3 \n", self.tokenizer.decode(batch[1][2].asnumpy(), skip_special_tokens=True), flush=True)
            # print("===== 4 \n", self.tokenizer.decode(batch[1][3].asnumpy(), skip_special_tokens=True), flush=True)

            # samples: prompt + generated response, right padding to seq_length=2048
            # original_samples/batch[1]: prompt + reference response, right padding to seq_length=2048
            samples = Tensor(samples, mstype.int32)
            original_samples = Tensor(batch_1, mstype.int32)

            # ====================== Reward model ===========================
            reward_start = time.time()
            scores = self.reward_fn(samples, original_samples=original_samples)
            reward_end = time.time()
            # ===============================================================
            print("scores: \n", scores, flush=True)

            self.ppo_model.policy_model.model.set_train(False)
            self.ref_model.model.set_train(False)

            # all_tokens: [pad, ..., pad, `prompt`, `response`, pad, ..., pad]
            print("left_padding_prompt: ", left_padding_prompt.shape, flush=True)
            print("resposne_array: ", resposne_array.shape, flush=True)
            all_tokens = np.concatenate((left_padding_prompt, resposne_array), axis=1)

            all_tokens = Tensor(all_tokens, mstype.int32)
            all_tokens = self.depend(all_tokens, scores)
            # ======================= Policy Model ================================
            policy_start = time.time()
            logprobs = self.ppo_model.policy_model(all_tokens, batch_valid_length=None,
                                                   is_first_iteration=True, samples=all_tokens)
            policy_end = time.time()
            print("logprob is ", logprobs.shape, flush=True)
            # ====================================================================
            all_tokens = self.depend(all_tokens, logprobs)
            # ======================= Critic Model ================================
            critic_start = time.time()
            values = self.ppo_model.critic_model(all_tokens)
            critic_end = time.time()
            print("values is ", values.shape, flush=True)
            # ====================================================================

            self.ref_model.model.add_flags_recursive(use_past=False)
            # ======================= Reference Model ================================
            ref_start = time.time()
            ref_logprobs = self.ref_model(all_tokens, samples=all_tokens)
            ref_end = time.time()
            print("ref_logprobs is ", ref_logprobs.shape, flush=True)
            # ========================================================================

            logprobs = logprobs.asnumpy()
            values = values.asnumpy()
            ref_logprobs = ref_logprobs.asnumpy()

            values = values[:, :-1]
            n_samples: int = samples.shape[0]

            start = self.ppo_config.max_prompt_length - 1
            end = self.ppo_config.seq_length - 1
            valid_length_response = (samples.asnumpy() != self.ppo_config.pad_token_id).astype(int).sum(1) \
                - (prompt_tensors.asnumpy() != self.ppo_config.pad_token_id).astype(int).sum(1)

            all_values = values[:, start:end]
            all_logprobs = logprobs[:, start:end]

            print("all_values: ", all_values.shape, flush=True)

            kl_divergence_estimate = self.ppo_model.kl_ctl.value.asnumpy() * (logprobs - ref_logprobs)

            kl_divergence_estimate = kl_divergence_estimate[:, start:end]

            rollout_count = 0
            for sample_idx in range(n_samples):
                sample_kl_divergence_estimate = kl_divergence_estimate[sample_idx]

                rewards = sample_kl_divergence_estimate

                # print("===== rewards[int(valid_length_response[sample_idx] - 1)]: ", rewards[int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)
                # print("===== valid_length_response: ", valid_length_response[sample_idx], flush=True)

                all_logprobs[sample_idx][int(valid_length_response[sample_idx]):] = 0.0
                all_values[sample_idx][int(valid_length_response[sample_idx]):] = 0.0
                all_values = np.array(all_values).reshape((n_samples, -1))
                rewards[int(valid_length_response[sample_idx]):] = 0.0

                index = valid_length_response[sample_idx] if valid_length_response[sample_idx] < len(rewards) else -1
                print("=====scores type: ", type(scores))
                if isinstance(scores, mindspore.Tensor):
                    scores = scores.asnumpy()
                rewards[int(index) - 1] += scores[sample_idx]
                # print("===== resposne_array[int(valid_length_response[sample_idx] - 1)]: ", resposne_array[sample_idx][int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)
                # print("===== rewards[int(valid_length_response[sample_idx] - 1)]: ", rewards[int(valid_length_response[sample_idx] - 3): int(valid_length_response[sample_idx] + 1)], flush=True)

                '''print("===== rewards: ", rewards, flush=True)
                np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/rewards.npy", rewards)
                print("===== values: ", values, flush=True)
                np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/values.npy", values)'''

                response_length = len(rewards)
                # print("===== response_length: ", response_length, flush=True)
                lastgaelam = 0
                advantages_reversed = []
                for k in range(response_length):
                    t = response_length - k - 1
                    nextvalues = all_values[sample_idx, t + 1] if t < response_length - 1 else 0.0
                    delta = rewards[t] + self.ppo_model.gamma * nextvalues - all_values[sample_idx, t]
                    lastgaelam = delta + self.ppo_model.gamma * self.ppo_model.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = np.stack(advantages_reversed[::-1])

                returns = advantages + all_values[sample_idx]

                '''print("===== advantages: ", advantages, flush=True)
                np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/advantages.npy", advantages)
                print("===== returns: ", returns, flush=True)
                np.save("/home/shiwenqi/mindspore-chatgpt-distributed/data/returns.npy", returns)

                exit()'''
                print("===== advantages & returns shape: ", len(advantages), len(returns))
                ppo_rl_elements.append(
                    PPORLElement(
                        query_tensor=prompt_tensors.asnumpy()[sample_idx],
                        # query_tensor=prompt_tensors[sample_idx],
                        response_tensor=all_tokens.asnumpy()[sample_idx],
                        # response_tensor=samples[sample_idx],
                        logprobs=all_logprobs[sample_idx],
                        values=all_values[sample_idx],
                        rewards=rewards,
                        advantages=advantages,
                        returns=returns,
                        pretrain_ids=pretrain_ids.asnumpy()[sample_idx]
                    )
                )

                rollout_count += 1
            rollout_total_end = time.time()
            print("Rollout elapsed time: ", rollout_total_end - rollout_total, flush=True)
            print("Each part of time is ", flush=True)
            print("==============================", flush=True)
            print(f"Generate: {generate_end - generate_start}", flush=True)
            print(f"Reward: {reward_end - reward_start}", flush=True)
            print(f"Policy: {policy_end - policy_start}", flush=True)
            print(f"Critic: {critic_end - critic_start}", flush=True)
            print(f"Reference: {ref_end - ref_start}", flush=True)
            print("==============================", flush=True)
        self.push_to_store(ppo_rl_elements)


if __name__ == "__main__":
    # samples = np.random.randint(low=0, high=15, size=(10, 550)).astype(np.int32)
    # get_scores(samples)
    # reward_fn(samples)
    context.set_context(device_target='Ascend', device_id=1, mode=mindspore.GRAPH_MODE)
    trainer = AcceleratePPOTrainer(ppo_config=PPOConfig)
    trainer.make_experience(num_rollouts=2)
